{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Word vectors from SEC filings using Gensim: word2vec model "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports & Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T14:57:52.650272Z",
     "start_time": "2020-06-21T14:57:52.648417Z"
    }
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T14:57:53.385881Z",
     "start_time": "2020-06-21T14:57:52.651821Z"
    }
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from time import time\n",
    "import logging\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from gensim.models import Word2Vec, KeyedVectors\n",
    "from gensim.models.word2vec import LineSentence\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import FuncFormatter\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T14:57:53.390272Z",
     "start_time": "2020-06-21T14:57:53.387016Z"
    }
   },
   "outputs": [],
   "source": [
    "sns.set_style('white')\n",
    "np.random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T14:57:53.399615Z",
     "start_time": "2020-06-21T14:57:53.391212Z"
    }
   },
   "outputs": [],
   "source": [
    "def format_time(t):\n",
    "    \"\"\"Format a duration in seconds as a zero-padded 'HH:MM:SS' string.\n",
    "\n",
    "    Args:\n",
    "        t: elapsed time in seconds (int or float).\n",
    "\n",
    "    Returns:\n",
    "        The duration rounded to the nearest whole second, e.g. '00:27:05'.\n",
    "    \"\"\"\n",
    "    # Round before splitting so a fractional remainder such as 59.7s carries\n",
    "    # over into the minutes instead of being rendered as ':60'.\n",
    "    total_seconds = int(round(t))\n",
    "    m, s = divmod(total_seconds, 60)\n",
    "    h, m = divmod(m, 60)\n",
    "    return f'{h:02d}:{m:02d}:{s:02d}'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T14:57:53.416919Z",
     "start_time": "2020-06-21T14:57:53.400799Z"
    }
   },
   "outputs": [],
   "source": [
    "sec_path = Path('..', 'data', 'sec-filings')\n",
    "ngram_path = sec_path / 'ngrams'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T14:57:53.426072Z",
     "start_time": "2020-06-21T14:57:53.418180Z"
    }
   },
   "outputs": [],
   "source": [
    "results_path = Path('results', 'sec-filings')\n",
    "model_path = results_path / 'models'\n",
    "log_path = results_path / 'logs'\n",
    "\n",
    "# exist_ok=True keeps the cell idempotent on re-run and avoids the\n",
    "# check-then-create race of a separate exists() test.\n",
    "for output_dir in (model_path, log_path):\n",
    "    output_dir.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Logging Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T14:57:53.437331Z",
     "start_time": "2020-06-21T14:57:53.427249Z"
    }
   },
   "outputs": [],
   "source": [
    "logging.basicConfig(\n",
    "    filename=log_path / 'word2vec.log',\n",
    "    level=logging.DEBUG,\n",
    "    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',\n",
    "    datefmt='%H:%M:%S')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## word2vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:33:56.321672Z",
     "start_time": "2020-06-21T15:33:56.319833Z"
    }
   },
   "outputs": [],
   "source": [
    "analogies_path = Path('data', 'analogies-en.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up Sentence Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T14:57:53.453489Z",
     "start_time": "2020-06-21T14:57:53.446620Z"
    }
   },
   "outputs": [],
   "source": [
    "NGRAMS = 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To facilitate memory-efficient text ingestion, the `LineSentence` class streams the provided text file lazily, yielding one sentence (one line of whitespace-separated tokens) at a time:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T14:57:53.478441Z",
     "start_time": "2020-06-21T14:57:53.454500Z"
    }
   },
   "outputs": [],
   "source": [
    "sentence_path = ngram_path  / f'ngrams_{NGRAMS}.txt'\n",
    "sentences = LineSentence(sentence_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train word2vec Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `Word2Vec` class in [gensim.models.word2vec](https://radimrehurek.com/gensim/models/word2vec.html) implements the skip-gram and CBOW architectures."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:24:58.507035Z",
     "start_time": "2020-06-21T14:57:53.479312Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Duration: 00:27:05\n"
     ]
    }
   ],
   "source": [
    "# Train a skip-gram model with negative sampling in a single pass over the\n",
    "# corpus and report the wall-clock duration.\n",
    "# NOTE(review): `size` and `iter` are the gensim 3.x argument names; gensim 4\n",
    "# renamed them to `vector_size` and `epochs` - confirm the pinned version.\n",
    "start = time()\n",
    "model = Word2Vec(sentences,\n",
    "                 sg=1,          # 1 for skip-gram; otherwise CBOW\n",
    "                 hs=0,          # hierarchical softmax if 1, negative sampling if 0\n",
    "                 size=300,      # vector dimensionality\n",
    "                 window=5,      # max. distance between current and predicted word\n",
    "                 min_count=50,  # ignore words with lower frequency\n",
    "                 negative=15,    # number of noise words drawn for negative sampling\n",
    "                 workers=4,     # number of worker threads\n",
    "                 iter=1,        # number of epochs (iterations over the corpus)\n",
    "                 alpha=0.05,   # initial learning rate\n",
    "                 min_alpha=0.0001 # final learning rate\n",
    "                ) \n",
    "print('Duration:', format_time(time() - start))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Persist model & vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:24:58.739904Z",
     "start_time": "2020-06-21T15:24:58.508205Z"
    }
   },
   "outputs": [],
   "source": [
    "model.save((model_path / 'word2vec_0.model').as_posix())\n",
    "model.wv.save((model_path / 'word_vectors_0.bin').as_posix())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load model and vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:24:59.036488Z",
     "start_time": "2020-06-21T15:24:58.756668Z"
    }
   },
   "outputs": [],
   "source": [
    "model = Word2Vec.load((model_path / 'word2vec_0.model').as_posix())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:24:59.294149Z",
     "start_time": "2020-06-21T15:24:59.038043Z"
    }
   },
   "outputs": [],
   "source": [
    "wv = KeyedVectors.load((model_path / 'word_vectors_0.bin').as_posix())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get vocabulary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:24:59.427873Z",
     "start_time": "2020-06-21T15:24:59.300032Z"
    }
   },
   "outputs": [],
   "source": [
    "# One [token, index, count] row per vocabulary entry, in the iteration\n",
    "# order of the model's vocabulary mapping.\n",
    "vocab = [[token, entry.index, entry.count]\n",
    "         for token, entry in model.wv.vocab.items()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:24:59.477941Z",
     "start_time": "2020-06-21T15:24:59.429275Z"
    }
   },
   "outputs": [],
   "source": [
    "vocab = (pd.DataFrame(vocab, \n",
    "                     columns=['token', 'idx', 'count'])\n",
    "         .sort_values('count', ascending=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:24:59.512055Z",
     "start_time": "2020-06-21T15:24:59.480511Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "Int64Index: 57384 entries, 715 to 25739\n",
      "Data columns (total 3 columns):\n",
      " #   Column  Non-Null Count  Dtype \n",
      "---  ------  --------------  ----- \n",
      " 0   token   57384 non-null  object\n",
      " 1   idx     57384 non-null  int64 \n",
      " 2   count   57384 non-null  int64 \n",
      "dtypes: int64(2), object(1)\n",
      "memory usage: 1.8+ MB\n"
     ]
    }
   ],
   "source": [
    "vocab.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:24:59.520332Z",
     "start_time": "2020-06-21T15:24:59.513350Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>token</th>\n",
       "      <th>idx</th>\n",
       "      <th>count</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>715</th>\n",
       "      <td>million</td>\n",
       "      <td>0</td>\n",
       "      <td>2340187</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>business</td>\n",
       "      <td>1</td>\n",
       "      <td>1696732</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2029</th>\n",
       "      <td>december</td>\n",
       "      <td>2</td>\n",
       "      <td>1512367</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>company</td>\n",
       "      <td>3</td>\n",
       "      <td>1490617</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>171</th>\n",
       "      <td>products</td>\n",
       "      <td>4</td>\n",
       "      <td>1367413</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2034</th>\n",
       "      <td>net</td>\n",
       "      <td>5</td>\n",
       "      <td>1246820</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>139</th>\n",
       "      <td>market</td>\n",
       "      <td>6</td>\n",
       "      <td>1148002</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>350</th>\n",
       "      <td>including</td>\n",
       "      <td>7</td>\n",
       "      <td>1109821</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1623</th>\n",
       "      <td>sales</td>\n",
       "      <td>8</td>\n",
       "      <td>1095619</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2244</th>\n",
       "      <td>costs</td>\n",
       "      <td>9</td>\n",
       "      <td>1018821</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          token  idx    count\n",
       "715     million    0  2340187\n",
       "0      business    1  1696732\n",
       "2029   december    2  1512367\n",
       "6       company    3  1490617\n",
       "171    products    4  1367413\n",
       "2034        net    5  1246820\n",
       "139      market    6  1148002\n",
       "350   including    7  1109821\n",
       "1623      sales    8  1095619\n",
       "2244      costs    9  1018821"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab.head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:24:59.545436Z",
     "start_time": "2020-06-21T15:24:59.521405Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "count      57384\n",
       "mean        4523\n",
       "std        35191\n",
       "min           50\n",
       "10%           60\n",
       "20%           75\n",
       "30%           96\n",
       "40%          128\n",
       "50%          176\n",
       "60%          263\n",
       "70%          442\n",
       "80%          946\n",
       "90%         3666\n",
       "max      2340187\n",
       "Name: count, dtype: int64"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab['count'].describe(percentiles=np.arange(.1, 1, .1)).astype(int)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate Analogies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:34:04.826106Z",
     "start_time": "2020-06-21T15:34:04.819570Z"
    }
   },
   "outputs": [],
   "source": [
    "def accuracy_by_category(acc, detail=True):\n",
    "    \"\"\"Summarise analogy results per section and return the 'total' row.\n",
    "\n",
    "    Args:\n",
    "        acc: iterable of dicts with the keys 'section', 'correct' and\n",
    "            'incorrect', where the latter two are sized collections.\n",
    "        detail: if True, print the per-section table sorted by accuracy.\n",
    "\n",
    "    Returns:\n",
    "        [correct, incorrect, average] for the section named 'total'.\n",
    "    \"\"\"\n",
    "    rows = []\n",
    "    for section in acc:\n",
    "        rows.append([section['section'],\n",
    "                     len(section['correct']),\n",
    "                     len(section['incorrect'])])\n",
    "    results = pd.DataFrame(rows, columns=['category', 'correct', 'incorrect'])\n",
    "\n",
    "    # Share of correct answers among all evaluated analogies of a section.\n",
    "    evaluated = results[['correct', 'incorrect']].sum(1)\n",
    "    results['average'] = results.correct.div(evaluated)\n",
    "\n",
    "    if detail:\n",
    "        print(results.sort_values('average', ascending=False))\n",
    "\n",
    "    is_total = results.category == 'total'\n",
    "    total_row = results.loc[is_total, ['correct', 'incorrect', 'average']]\n",
    "    return total_row.squeeze().tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:34:58.842380Z",
     "start_time": "2020-06-21T15:34:05.044496Z"
    }
   },
   "outputs": [],
   "source": [
    "detailed_accuracy = model.wv.accuracy(analogies_path.as_posix(), case_insensitive=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:34:58.861949Z",
     "start_time": "2020-06-21T15:34:58.843367Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                       category  correct  incorrect   average\n",
      "10  gram6-nationality-adjective      367        445  0.451970\n",
      "12                 gram8-plural      228        372  0.380000\n",
      "7             gram3-comparative      260        610  0.298851\n",
      "14                        total     2083       5543  0.273145\n",
      "9      gram5-present-participle      100        280  0.263158\n",
      "2                 city-in-state      767       2151  0.262851\n",
      "13           gram9-plural-verbs       75        231  0.245098\n",
      "6                gram2-opposite       47        163  0.223810\n",
      "5     gram1-adjective-to-adverb       59        247  0.192810\n",
      "8             gram4-superlative       46        194  0.191667\n",
      "11             gram7-past-tense      119        531  0.183077\n",
      "4                        family        4         26  0.133333\n",
      "0      capital-common-countries        8        102  0.072727\n",
      "3                      currency        3        149  0.019737\n",
      "1                 capital-world        0         42  0.000000\n"
     ]
    }
   ],
   "source": [
    "summary = accuracy_by_category(detailed_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:34:58.879346Z",
     "start_time": "2020-06-21T15:34:58.863357Z"
    }
   },
   "outputs": [],
   "source": [
    "def eval_analogies(w2v, max_vocab=15000):\n",
    "    \"\"\"Score a word2vec model on the analogy test file, by category.\n",
    "\n",
    "    Args:\n",
    "        w2v: model exposing `wv.accuracy()` (a gensim Word2Vec model in\n",
    "            this notebook).\n",
    "        max_vocab: forwarded as `restrict_vocab` to limit the vocabulary\n",
    "            used for the evaluation.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with one row per section and the columns 'category',\n",
    "        'correct', 'incorrect' and 'average' (share of correct answers;\n",
    "        NaN for sections without any evaluated analogy).\n",
    "    \"\"\"\n",
    "    # restrict_vocab follows the caller's max_vocab rather than a fixed value.\n",
    "    accuracy = w2v.wv.accuracy(analogies_path,\n",
    "                               restrict_vocab=max_vocab,\n",
    "                               case_insensitive=True)\n",
    "    counts = [[section['section'], len(section['correct']), len(section['incorrect'])]\n",
    "              for section in accuracy]\n",
    "    return (pd.DataFrame(counts, columns=['category', 'correct', 'incorrect'])\n",
    "            .assign(average=lambda x: x.correct.div(x.correct.add(x.incorrect))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:34:58.883984Z",
     "start_time": "2020-06-21T15:34:58.880915Z"
    }
   },
   "outputs": [],
   "source": [
    "def total_accuracy(w2v):\n",
    "    \"\"\"Return [correct, incorrect, average] of the 'total' analogy section.\n",
    "\n",
    "    Args:\n",
    "        w2v: model passed through to `eval_analogies`.\n",
    "    \"\"\"\n",
    "    by_category = eval_analogies(w2v)\n",
    "    is_total = by_category.category == 'total'\n",
    "    total_row = by_category.loc[is_total, ['correct', 'incorrect', 'average']]\n",
    "    return total_row.squeeze().tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:35:16.240602Z",
     "start_time": "2020-06-21T15:34:58.885201Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>category</th>\n",
       "      <th>correct</th>\n",
       "      <th>incorrect</th>\n",
       "      <th>average</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>capital-common-countries</td>\n",
       "      <td>2</td>\n",
       "      <td>10</td>\n",
       "      <td>0.166667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>capital-world</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>city-in-state</td>\n",
       "      <td>231</td>\n",
       "      <td>493</td>\n",
       "      <td>0.319061</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>currency</td>\n",
       "      <td>3</td>\n",
       "      <td>49</td>\n",
       "      <td>0.057692</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>family</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>gram1-adjective-to-adverb</td>\n",
       "      <td>44</td>\n",
       "      <td>166</td>\n",
       "      <td>0.209524</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>gram2-opposite</td>\n",
       "      <td>17</td>\n",
       "      <td>73</td>\n",
       "      <td>0.188889</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>gram3-comparative</td>\n",
       "      <td>234</td>\n",
       "      <td>366</td>\n",
       "      <td>0.390000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>gram4-superlative</td>\n",
       "      <td>22</td>\n",
       "      <td>68</td>\n",
       "      <td>0.244444</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>gram5-present-participle</td>\n",
       "      <td>91</td>\n",
       "      <td>215</td>\n",
       "      <td>0.297386</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>gram6-nationality-adjective</td>\n",
       "      <td>268</td>\n",
       "      <td>194</td>\n",
       "      <td>0.580087</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>gram7-past-tense</td>\n",
       "      <td>111</td>\n",
       "      <td>351</td>\n",
       "      <td>0.240260</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>gram8-plural</td>\n",
       "      <td>96</td>\n",
       "      <td>86</td>\n",
       "      <td>0.527473</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>gram9-plural-verbs</td>\n",
       "      <td>74</td>\n",
       "      <td>136</td>\n",
       "      <td>0.352381</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>total</td>\n",
       "      <td>1193</td>\n",
       "      <td>2207</td>\n",
       "      <td>0.350882</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                       category  correct  incorrect   average\n",
       "0      capital-common-countries        2         10  0.166667\n",
       "1                 capital-world        0          0       NaN\n",
       "2                 city-in-state      231        493  0.319061\n",
       "3                      currency        3         49  0.057692\n",
       "4                        family        0          0       NaN\n",
       "5     gram1-adjective-to-adverb       44        166  0.209524\n",
       "6                gram2-opposite       17         73  0.188889\n",
       "7             gram3-comparative      234        366  0.390000\n",
       "8             gram4-superlative       22         68  0.244444\n",
       "9      gram5-present-participle       91        215  0.297386\n",
       "10  gram6-nationality-adjective      268        194  0.580087\n",
       "11             gram7-past-tense      111        351  0.240260\n",
       "12                 gram8-plural       96         86  0.527473\n",
       "13           gram9-plural-verbs       74        136  0.352381\n",
       "14                        total     1193       2207  0.350882"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accuracy = eval_analogies(model)\n",
    "accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Validate Vector Arithmetic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:35:16.256130Z",
     "start_time": "2020-06-21T15:35:16.241718Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "          term  similarity\n",
      "0         ipad    0.691182\n",
      "1      android    0.632260\n",
      "2          app    0.609227\n",
      "3   smartphone    0.605110\n",
      "4  smart_phone    0.580258\n",
      "5  smartphones    0.577489\n",
      "6     keyboard    0.559338\n",
      "7   mobile_app    0.525289\n",
      "8   downloaded    0.520682\n",
      "9           pc    0.511132\n"
     ]
    }
   ],
   "source": [
    "sims=model.wv.most_similar(positive=['iphone'], restrict_vocab=15000)\n",
    "print(pd.DataFrame(sims, columns=['term', 'similarity']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:35:16.279303Z",
     "start_time": "2020-06-21T15:35:16.257107Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "             term  similarity\n",
      "0  united_kingdom    0.577512\n",
      "1         germany    0.561950\n",
      "2         belgium    0.542092\n",
      "3     netherlands    0.537289\n",
      "4       australia    0.515880\n",
      "5           italy    0.512862\n",
      "6          sweden    0.502201\n",
      "7           spain    0.496639\n",
      "8       singapore    0.495658\n",
      "9         austria    0.494052\n"
     ]
    }
   ],
   "source": [
    "analogy = model.wv.most_similar(positive=['france', 'london'], \n",
    "                                negative=['paris'], \n",
    "                                restrict_vocab=15000)\n",
    "print(pd.DataFrame(analogy, columns=['term', 'similarity']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check similarity for random words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:35:29.289819Z",
     "start_time": "2020-06-21T15:35:29.227163Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>paradigm</th>\n",
       "      <th>fundamentally</th>\n",
       "      <th>patient</th>\n",
       "      <th>quality</th>\n",
       "      <th>percent</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>paradigms</td>\n",
       "      <td>transform</td>\n",
       "      <td>patients</td>\n",
       "      <td>cleanliness</td>\n",
       "      <td>percent_percent</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>revolutionize</td>\n",
       "      <td>transforming</td>\n",
       "      <td>clinician</td>\n",
       "      <td>high_quality</td>\n",
       "      <td>percentage</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>personalized_medicine</td>\n",
       "      <td>profoundly</td>\n",
       "      <td>physician</td>\n",
       "      <td>originality</td>\n",
       "      <td>total</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>evolutionary</td>\n",
       "      <td>alter</td>\n",
       "      <td>outpatients</td>\n",
       "      <td>dependability</td>\n",
       "      <td>approximately</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>muscular_dystrophies</td>\n",
       "      <td>way</td>\n",
       "      <td>givers</td>\n",
       "      <td>timeliness</td>\n",
       "      <td>basis_points</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>adoptive</td>\n",
       "      <td>transformational</td>\n",
       "      <td>inpatient</td>\n",
       "      <td>consistency</td>\n",
       "      <td>mwh_mwh</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>radiance</td>\n",
       "      <td>revolutionized</td>\n",
       "      <td>hospital</td>\n",
       "      <td>trustworthiness</td>\n",
       "      <td>respectively</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>thyroid_nodule</td>\n",
       "      <td>reshape</td>\n",
       "      <td>protection_affordable</td>\n",
       "      <td>freshness</td>\n",
       "      <td>percentages</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>mx_icp</td>\n",
       "      <td>paradigms</td>\n",
       "      <td>ambulation</td>\n",
       "      <td>friendliness</td>\n",
       "      <td>rces</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>invasive_candida</td>\n",
       "      <td>converging</td>\n",
       "      <td>ophthalmologist</td>\n",
       "      <td>professionalism</td>\n",
       "      <td>wheel_rvs</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                paradigm     fundamentally                patient  \\\n",
       "0              paradigms         transform               patients   \n",
       "1          revolutionize      transforming              clinician   \n",
       "2  personalized_medicine        profoundly              physician   \n",
       "3           evolutionary             alter            outpatients   \n",
       "4   muscular_dystrophies               way                 givers   \n",
       "5               adoptive  transformational              inpatient   \n",
       "6               radiance    revolutionized               hospital   \n",
       "7         thyroid_nodule           reshape  protection_affordable   \n",
       "8                 mx_icp         paradigms             ambulation   \n",
       "9       invasive_candida        converging        ophthalmologist   \n",
       "\n",
       "           quality          percent  \n",
       "0      cleanliness  percent_percent  \n",
       "1     high_quality       percentage  \n",
       "2      originality            total  \n",
       "3    dependability    approximately  \n",
       "4       timeliness     basis_points  \n",
       "5      consistency          mwh_mwh  \n",
       "6  trustworthiness     respectively  \n",
       "7        freshness      percentages  \n",
       "8     friendliness             rces  \n",
       "9  professionalism        wheel_rvs  "
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "VALID_SET = 5  # Random set of words to get nearest neighbors for\n",
    "VALID_WINDOW = 100  # Most frequent words to draw validation set from\n",
    "valid_examples = np.random.choice(VALID_WINDOW, size=VALID_SET, replace=False)\n",
    "\n",
    "# `vocab` kept its pre-sort row labels when it was ordered by count, so a\n",
    "# label-based `vocab.loc[i]` does not return the i-th most frequent token.\n",
    "# Look tokens up via the `idx` column instead, i.e. the token's index in the\n",
    "# model vocabulary (0 is the most frequent token in the table shown earlier).\n",
    "token_by_rank = vocab.set_index('idx')['token']\n",
    "\n",
    "# One column per sampled word, listing its nearest neighbors.\n",
    "neighbors = {}\n",
    "for rank in sorted(valid_examples):\n",
    "    word = token_by_rank.loc[rank]\n",
    "    neighbors[word] = [term for term, _ in model.wv.most_similar(word)]\n",
    "similars = pd.DataFrame(neighbors)\n",
    "similars"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Continue Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2020-06-22T03:12:37.505Z"
    }
   },
   "outputs": [],
   "source": [
    "# Train for further single epochs, track the total analogy accuracy and\n",
    "# checkpoint model and vectors whenever the best score so far is beaten.\n",
    "accuracies = [summary]\n",
    "best_accuracy = summary[-1]\n",
    "for i in range(1, 15):\n",
    "    start = time()\n",
    "    model.train(sentences, epochs=1, total_examples=model.corpus_count)\n",
    "    detailed_accuracy = model.wv.accuracy(analogies_path)\n",
    "    epoch_summary = accuracy_by_category(detailed_accuracy, detail=False)\n",
    "    accuracies.append(epoch_summary)\n",
    "    epoch_accuracy = epoch_summary[-1]\n",
    "    print(f'{i:02} | Duration: {format_time(time() - start)} | Accuracy: {epoch_accuracy:.2%} ')\n",
    "    if epoch_accuracy > best_accuracy:\n",
    "        model.save((model_path / f'word2vec_{i:02}.model').as_posix())\n",
    "        model.wv.save((model_path / f'word_vectors_{i:02}.bin').as_posix())\n",
    "        best_accuracy = epoch_accuracy\n",
    "    # Rewrite the accuracy history after every epoch.\n",
    "    history = pd.DataFrame(accuracies, columns=['correct', 'wrong', 'average'])\n",
    "    history.to_csv(model_path / 'accuracies.csv', index=False)\n",
    "model.wv.save((model_path / 'word_vectors_final.bin').as_posix())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sample Output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "|Epoch|Duration| Accuracy|\n",
    "|---|---|---|\n",
    "01 | 00:14:00 | 31.64% | \n",
    "02 | 00:14:21 | 31.72% | \n",
    "03 | 00:14:34 | 33.65% | \n",
    "04 | 00:16:11 | 34.03% | \n",
    "05 | 00:13:51 | 33.04% | \n",
    "06 | 00:13:46 | 33.28% | \n",
    "07 | 00:13:51 | 33.10% | \n",
    "08 | 00:13:54 | 34.11% | \n",
    "09 | 00:13:54 | 33.70% | \n",
    "10 | 00:13:55 | 34.09% | \n",
    "11 | 00:13:57 | 35.06% | \n",
    "12 | 00:13:38 | 33.79% | \n",
    "13 | 00:13:26 | 32.40% | "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2020-06-22T03:12:50.880Z"
    }
   },
   "outputs": [],
   "source": [
    "(pd.DataFrame(accuracies, \n",
    "             columns=['correct', 'wrong', 'average'])\n",
    " .to_csv(results_path / 'accuracies.csv', index=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:24:59.759064Z",
     "start_time": "2020-06-21T14:57:52.725Z"
    }
   },
   "outputs": [],
   "source": [
    "best_model = Word2Vec.load((results_path / 'word2vec_11.model').as_posix())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:24:59.759787Z",
     "start_time": "2020-06-21T14:57:52.726Z"
    }
   },
   "outputs": [],
   "source": [
    "detailed_accuracy = best_model.wv.accuracy(analogies_path.as_posix(), case_insensitive=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:24:59.760326Z",
     "start_time": "2020-06-21T14:57:52.728Z"
    }
   },
   "outputs": [],
   "source": [
    "summary = accuracy_by_category(detailed_accuracy)\n",
    "print('Base Accuracy: Correct {:,.0f} | Wrong {:,.0f} | Avg {:,.2%}\\n'.format(*summary))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:24:59.760862Z",
     "start_time": "2020-06-21T14:57:52.730Z"
    }
   },
   "outputs": [],
   "source": [
    "cat_dict = {'capital-common-countries':'Capitals',\n",
    "            'capital-world':'Capitals RoW',\n",
    "            'city-in-state':'City-State',\n",
    "            'currency':'Currency',\n",
    "            'family':'Famliy',\n",
    "            'gram1-adjective-to-adverb':'Adj-Adverb',\n",
    "            'gram2-opposite':'Opposite',\n",
    "            'gram3-comparative':'Comparative',\n",
    "            'gram4-superlative':'Superlative',\n",
    "            'gram5-present-participle':'Pres. Part.',\n",
    "            'gram6-nationality-adjective':'Nationality',\n",
    "            'gram7-past-tense':'Past Tense',\n",
    "            'gram8-plural':'Plural',\n",
    "            'gram9-plural-verbs':'Plural Verbs',\n",
    "            'total':'Total'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:24:59.761465Z",
     "start_time": "2020-06-21T14:57:52.732Z"
    }
   },
   "outputs": [],
   "source": [
    "results = [[c['section'], len(c['correct']), len(c['incorrect'])] for c in detailed_accuracy]\n",
    "results = pd.DataFrame(results, columns=['category', 'correct', 'incorrect'])\n",
    "results['category'] = results.category.map(cat_dict)\n",
    "results['average'] = results.correct.div(results[['correct', 'incorrect']].sum(1))\n",
    "results = results.rename(columns=str.capitalize).set_index('Category')\n",
    "total = results.loc['Total']\n",
    "results = results.drop('Total')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:24:59.762031Z",
     "start_time": "2020-06-21T14:57:52.734Z"
    }
   },
   "outputs": [],
   "source": [
    "most_sim = best_model.wv.most_similar(positive=['woman', 'king'], negative=['man'], topn=20)\n",
    "pd.DataFrame(most_sim, columns=['token', 'similarity'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-21T15:24:59.762602Z",
     "start_time": "2020-06-21T14:57:52.735Z"
    }
   },
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(figsize=(16, 5), ncols=2)\n",
    "\n",
    "axes[0] = results.loc[:, ['Correct', 'Incorrect']].plot.bar(stacked=True, ax=axes[0]\n",
    "                                                           , title='Analogy Accuracy')\n",
    "ax1 = results.loc[:, ['Average']].plot(ax=axes[0], secondary_y=True, lw=1, c='k', rot=35)\n",
    "ax1.yaxis.set_major_formatter(FuncFormatter(lambda y, _: '{:.0%}'.format(y)))\n",
    "\n",
    "(pd.DataFrame(most_sim, columns=['token', 'similarity'])\n",
    " .set_index('token').similarity\n",
    " .sort_values().tail(10).plot.barh(xlim=(.3, .37), ax=axes[1], title='Closest matches for Woman + King - Man'))\n",
    "fig.tight_layout()\n",
    "# fig.savefig('figures/w2v_sec_evaluation', dpi=300);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": true,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
